{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "deep_transfer_tutorial.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.5"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BOlo4Ctmtm4p"
      },
      "source": [
        "# Deep transfer learning tutorial\n",
        "This notebook covers two popular paradigms of transfer learning: **finetuning** and **domain adaptation**.\n",
        "Since the two share most of their code, we show how both work in a single notebook.\n",
        "I think that transfer learning and domain adaptation are both easy, and that creating a dedicated library or package for this simple purpose only makes things more difficult.\n",
        "The point of this notebook is that we **don't even need to install a library or package** to train a domain adaptation or finetuned model."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9P9av5SNtm4r"
      },
      "source": [
        "## Preparation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mcjr5uA7tm4r"
      },
      "source": [
        "First of all, install PyTorch (the `torch` package) and `torchvision`.\n",
        "Skip this step if you have already installed them."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "f3rx2L6hSvGY",
        "outputId": "9ea41815-ee01-4bdb-de6b-042696efbbd0"
      },
      "source": [
        "!pip install torch torchvision"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: torch in /home/jindwang/miniconda3/lib/python3.8/site-packages (1.7.1)\r\n",
            "Requirement already satisfied: torchvision in /home/jindwang/miniconda3/lib/python3.8/site-packages (0.8.2)\r\n",
            "Requirement already satisfied: typing-extensions in /home/jindwang/miniconda3/lib/python3.8/site-packages (from torch) (3.7.4.3)\r\n",
            "Requirement already satisfied: numpy in /home/jindwang/miniconda3/lib/python3.8/site-packages (from torch) (1.19.5)\r\n",
            "Requirement already satisfied: pillow>=4.1.1 in /home/jindwang/miniconda3/lib/python3.8/site-packages (from torchvision) (8.1.0)\r\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qWlAlAeRtm4s"
      },
      "source": [
        "Then, prepare the dataset you need. Here, we will use the classic **Office-31** dataset.\n",
        "We just need to download it and then extract it.\n",
        "Skip this step if you already have the data on your disk."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4UCJrNkMtm4s"
      },
      "source": [
        "!wget https://transferlearningdrive.blob.core.windows.net/teamdrive/dataset/office31.zip\n",
        "!unzip office31.zip"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xKVY0mi6tm4t"
      },
      "source": [
        "To verify the dataset, we display its directory structure."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bpqUe9URtm4t"
      },
      "source": [
        "!apt install -y tree\n",
        "!tree -d -L 1 office31"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yyP-9VnQtm4t"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6R098oQTS_TC"
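      },
      "source": [
        "import os\n",
        "import torch\n",
        "import torchvision\n",
        "from torchvision import transforms, datasets\n",
        "import torch.nn as nn\n",
        "import time\n",
        "from torchvision import models\n",
        "# Select the GPU to use. Index 0 is the first GPU, which always exists when CUDA is available;\n",
        "# change the index if you want a different device on a multi-GPU machine.\n",
        "torch.cuda.set_device(0)"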
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "K9Ut17hwUDq-"
      },
      "source": [
        "Set the dataset folder, batch size, number of classes, and domain names."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "j4e--IIRU68M"
      },
      "source": [
        "data_folder = 'office31'\n",
        "batch_size = 32\n",
        "n_class = 31\n",
        "domain_src, domain_tar = 'amazon', 'webcam'"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XA6e2YaPtm4u"
      },
      "source": [
        "## Data load\n",
        "Now, define a data loader function."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6JtN_bK9VFcM"
      },
      "source": [
        "def load_data(root_path, domain, batch_size, phase):\n",
        "    transform_dict = {\n",
        "        # Random augmentation for the source (training) domain\n",
        "        'src': transforms.Compose(\n",
        "        [transforms.RandomResizedCrop(224),\n",
        "         transforms.RandomHorizontalFlip(),\n",
        "         transforms.ToTensor(),\n",
        "         transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
        "                              std=[0.229, 0.224, 0.225]),\n",
        "         ]),\n",
        "        # Deterministic preprocessing for the target domain.\n",
        "        # Resize((224, 224)) gives every image the same shape so that batches can be stacked.\n",
        "        'tar': transforms.Compose(\n",
        "        [transforms.Resize((224, 224)),\n",
        "         transforms.ToTensor(),\n",
        "         transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
        "                              std=[0.229, 0.224, 0.225]),\n",
        "         ])}\n",
        "    data = datasets.ImageFolder(root=os.path.join(root_path, domain), transform=transform_dict[phase])\n",
        "    # Shuffle only the source data. Keep every sample (drop_last=False), because the\n",
        "    # accuracy computed later is divided by the full dataset size.\n",
        "    data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=phase == 'src', drop_last=False, num_workers=4)\n",
        "    return data_loader"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PHy09lD9tm4v"
      },
      "source": [
        "Load the data using the above function to test it."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Jf_Gw2HRVJM_"
      },
      "source": [
        "src_loader = load_data(data_folder, domain_src, batch_size, phase='src')\n",
        "tar_loader = load_data(data_folder, domain_tar, batch_size, phase='tar')\n",
        "print(f'Source data number: {len(src_loader.dataset)}')\n",
        "print(f'Target data number: {len(tar_loader.dataset)}')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "caKMn438tm4v"
      },
      "source": [
        "## Define the finetune model\n",
        "The finetune model is based on ResNet-50 because of its popularity. Of course, you can use other base networks.\n",
        "The main logic of this class is to load the pretrained ResNet-50 and keep all of its layers except the last one, which we replace with a new FC layer for classification. The original ResNet-50 classifies 1000 classes, whereas we only need it to classify 31."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OXAjmY7pVK8t"
      },
      "source": [
        "class TransferModel(nn.Module):\n",
        "    def __init__(self,\n",
        "                base_model : str = 'resnet50',\n",
        "                pretrain : bool = True,\n",
        "                n_class : int = 31):\n",
        "        super(TransferModel, self).__init__()\n",
        "        self.base_model = base_model\n",
        "        self.pretrain = pretrain\n",
        "        self.n_class = n_class\n",
        "        if self.base_model == 'resnet50':\n",
        "            # Honor the `pretrain` flag instead of always loading pretrained weights\n",
        "            self.model = torchvision.models.resnet50(pretrained=self.pretrain)\n",
        "            # Replace the original 1000-way head with an n_class-way FC layer\n",
        "            n_features = self.model.fc.in_features\n",
        "            fc = torch.nn.Linear(n_features, n_class)\n",
        "            self.model.fc = fc\n",
        "        else:\n",
        "            # Use other models you like, such as vgg or alexnet\n",
        "            raise NotImplementedError(f'Unsupported base model: {base_model}')\n",
        "        self.model.fc.weight.data.normal_(0, 0.005)\n",
        "        self.model.fc.bias.data.fill_(0.1)\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.model(x)\n",
        "    \n",
        "    def predict(self, x):\n",
        "        return self.forward(x)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UAqT-0jRtm4w"
      },
      "source": [
        "Now, we define a model and test it using a random tensor."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LewRmYIvXEIo"
      },
      "source": [
        "model = TransferModel().cuda()\n",
        "RAND_TENSOR = torch.randn(1, 3, 224, 224).cuda()\n",
        "output = model(RAND_TENSOR)\n",
        "print(output)\n",
        "print(output.shape)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UpX-fuiXtm4w"
      },
      "source": [
        "## Finetune ResNet-50\n",
        "Define some variables. Note that Office-31 doesn't have a validation set, so we use its target domain as the validation set."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0uNeAiPatm4w"
      },
      "source": [
        "Now the most important part: write the finetune function.\n",
        "This function is pretty easy: it is basically a standard classification training loop. We train the model on the 'src' domain, validate it on the 'val' domain, and then test it on the 'tar' domain.\n",
        "The only difference is that the Office-31 dataset has no validation set, so we will use the target domain as the validation set. For your own data, you should use its standard validation set.\n",
        "We also set an `early_stop` variable: training stops once the validation accuracy has not improved for that many consecutive epochs."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "h74gKIVqtm4w"
      },
      "source": [
        "dataloaders = {'src': src_loader,\n",
        "               'val': tar_loader,\n",
        "               'tar': tar_loader}\n",
        "n_epoch = 100\n",
        "criterion = nn.CrossEntropyLoss()\n",
        "early_stop = 20"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-fz_FlAIXsTF"
      },
      "source": [
        "def finetune(model, dataloaders, optimizer):\n",
        "    since = time.time()\n",
        "    best_acc = 0\n",
        "    stop = 0\n",
        "    for epoch in range(0, n_epoch):\n",
        "        stop += 1\n",
        "        # You can uncomment this line for scheduling learning rate\n",
        "        # lr_schedule(optimizer, epoch)\n",
        "        for phase in ['src', 'val', 'tar']:\n",
        "            if phase == 'src':\n",
        "                model.train()\n",
        "            else:\n",
        "                model.eval()\n",
        "            total_loss, correct = 0, 0\n",
        "            for inputs, labels in dataloaders[phase]:\n",
        "                inputs, labels = inputs.cuda(), labels.cuda()\n",
        "                optimizer.zero_grad()\n",
        "                with torch.set_grad_enabled(phase == 'src'):\n",
        "                    outputs = model(inputs)\n",
        "                    loss = criterion(outputs, labels)\n",
        "                preds = torch.max(outputs, 1)[1]\n",
        "                if phase == 'src':\n",
        "                    loss.backward()\n",
        "                    optimizer.step()\n",
        "                total_loss += loss.item() * inputs.size(0)\n",
        "                correct += torch.sum(preds == labels.data)\n",
        "            epoch_loss = total_loss / len(dataloaders[phase].dataset)\n",
        "            epoch_acc = correct.double() / len(dataloaders[phase].dataset)\n",
        "            print(f'Epoch: [{epoch:02d}/{n_epoch:02d}]---{phase}, loss: {epoch_loss:.6f}, acc: {epoch_acc:.4f}')\n",
        "            if phase == 'val' and epoch_acc > best_acc:\n",
        "                stop = 0\n",
        "                best_acc = epoch_acc\n",
        "                torch.save(model.state_dict(), 'model.pkl')\n",
        "        if stop >= early_stop:\n",
        "            break\n",
        "        print()\n",
        "   \n",
        "    time_pass = time.time() - since\n",
        "    print(f'Training complete in {time_pass // 60:.0f}m {time_pass % 60:.0f}s')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rGUZGrm7tm4x"
      },
      "source": [
        "Now, define some training parameters and the optimizer. For simplicity, we use SGD, and the learning rate for the FC layer is 10 times that of the other layers, which is a common trick."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HGVIu0JXZaGT"
      },
      "source": [
        "param_group = []\n",
        "learning_rate = 0.0001\n",
        "momentum = 0.9  # usual SGD momentum; 5e-4 is a typical weight-decay value, not a momentum\n",
        "for k, v in model.named_parameters():\n",
        "    if 'fc' not in k:\n",
        "        # Pretrained backbone layers: base learning rate\n",
        "        param_group += [{'params': v, 'lr': learning_rate}]\n",
        "    else:\n",
        "        # Newly initialized FC layer: 10x learning rate\n",
        "        param_group += [{'params': v, 'lr': learning_rate * 10}]\n",
        "optimizer = torch.optim.SGD(param_group, momentum=momentum)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F7zGR_PFtm4y"
      },
      "source": [
        "## Train and test\n",
        "Now we can train and test it."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uKKrah-AZsHt"
      },
      "source": [
        "finetune(model, dataloaders, optimizer)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GR5Y1x4btm4y"
      },
      "source": [
        "def test(model, target_test_loader):\n",
        "    model.eval()\n",
        "    correct = 0\n",
        "    len_target_dataset = len(target_test_loader.dataset)\n",
        "    with torch.no_grad():\n",
        "        for data, target in target_test_loader:\n",
        "            data, target = data.cuda(), target.cuda()\n",
        "            s_output = model.predict(data)\n",
        "            pred = torch.max(s_output, 1)[1]\n",
        "            correct += torch.sum(pred == target)\n",
        "    acc = correct.double() / len(target_test_loader.dataset)\n",
        "    return acc"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1O2CBxmxtm4y"
      },
      "source": [
        "model.load_state_dict(torch.load('model.pkl'))\n",
        "acc_test = test(model, dataloaders['tar'])\n",
        "print(f'Test accuracy: {acc_test}')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I7Iuk_Mitm4z"
      },
      "source": [
        "That's it for finetuning. Of course, you should use a learning rate decay schedule in real training, but that is not our goal here.\n",
        "Next, we will continue to use the same dataloaders for domain adaptation."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bO4c_QcGtm4z"
      },
      "source": [
        "## Domain adaptation\n",
        "Now we turn to domain adaptation."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VJwcwLQftm40"
      },
      "source": [
        "## Logic for domain adaptation\n",
        "The logic for domain adaptation is mostly similar to finetuning, except that we must add a loss to the finetune model to **penalize the distribution discrepancy** between the two domains.\n",
        "Therefore, the main differences are:\n",
        "- Define a **loss function** to compute the distance between the domains (which is the main contribution of most existing DA papers).\n",
        "- Define a new model class that uses that loss function in its **forward** pass.\n",
        "- Write a slightly different training script, since each step takes **source data, source labels, and target data**."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jy_1xwdJtm40"
      },
      "source": [
        "### Loss function\n",
        "The most popular loss function for DA is **MMD (Maximum Mean Discrepancy)**. For comparison, we also use another popular loss, **CORAL (CORrelation ALignment)**. They are defined as follows."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z3-wKorUtm40"
      },
      "source": [
        "#### MMD loss"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MpQH6VFwtm41"
      },
      "source": [
        "class MMD_loss(nn.Module):\n",
        "    def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):\n",
        "        super(MMD_loss, self).__init__()\n",
        "        self.kernel_num = kernel_num\n",
        "        self.kernel_mul = kernel_mul\n",
        "        self.fix_sigma = None\n",
        "        self.kernel_type = kernel_type\n",
        "\n",
        "    def gaussian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):\n",
        "        # Pairwise squared L2 distances between all source and target samples\n",
        "        n_samples = int(source.size()[0]) + int(target.size()[0])\n",
        "        total = torch.cat([source, target], dim=0)\n",
        "        total0 = total.unsqueeze(0).expand(\n",
        "            int(total.size(0)), int(total.size(0)), int(total.size(1)))\n",
        "        total1 = total.unsqueeze(1).expand(\n",
        "            int(total.size(0)), int(total.size(0)), int(total.size(1)))\n",
        "        L2_distance = ((total0-total1)**2).sum(2)\n",
        "        # Use the mean pairwise distance as the base bandwidth unless one is given\n",
        "        if fix_sigma:\n",
        "            bandwidth = fix_sigma\n",
        "        else:\n",
        "            bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)\n",
        "        bandwidth /= kernel_mul ** (kernel_num // 2)\n",
        "        # Sum kernel_num Gaussian kernels whose bandwidths form a geometric series\n",
        "        bandwidth_list = [bandwidth * (kernel_mul**i)\n",
        "                          for i in range(kernel_num)]\n",
        "        kernel_val = [torch.exp(-L2_distance / bandwidth_temp)\n",
        "                      for bandwidth_temp in bandwidth_list]\n",
        "        return sum(kernel_val)\n",
        "\n",
        "    def linear_mmd2(self, f_of_X, f_of_Y):\n",
        "        # Squared Euclidean distance between the two domain means\n",
        "        delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0)\n",
        "        loss = delta.dot(delta)\n",
        "        return loss\n",
        "\n",
        "    def forward(self, source, target):\n",
        "        if self.kernel_type == 'linear':\n",
        "            return self.linear_mmd2(source, target)\n",
        "        elif self.kernel_type == 'rbf':\n",
        "            batch_size = int(source.size()[0])\n",
        "            kernels = self.gaussian_kernel(\n",
        "                source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)\n",
        "            XX = torch.mean(kernels[:batch_size, :batch_size])\n",
        "            YY = torch.mean(kernels[batch_size:, batch_size:])\n",
        "            XY = torch.mean(kernels[:batch_size, batch_size:])\n",
        "            YX = torch.mean(kernels[batch_size:, :batch_size])\n",
        "            loss = torch.mean(XX + YY - XY - YX)\n",
        "            return loss\n",
        "        else:\n",
        "            raise ValueError(f'Unknown kernel type: {self.kernel_type}')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NcfUy_2Dtm41"
      },
      "source": [
        "#### CORAL loss"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uZhKJq15tm41"
      },
      "source": [
        "def CORAL(source, target):\n",
        "    d = source.size(1)\n",
        "    ns, nt = source.size(0), target.size(0)\n",
        "\n",
        "    # source covariance\n",
        "    tmp_s = torch.ones((1, ns), device=source.device) @ source\n",
        "    cs = (source.t() @ source - (tmp_s.t() @ tmp_s) / ns) / (ns - 1)\n",
        "\n",
        "    # target covariance\n",
        "    tmp_t = torch.ones((1, nt), device=target.device) @ target\n",
        "    ct = (target.t() @ target - (tmp_t.t() @ tmp_t) / nt) / (nt - 1)\n",
        "\n",
        "    # frobenius norm\n",
        "    loss = (cs - ct).pow(2).sum().sqrt()\n",
        "    loss = loss / (4 * d * d)\n",
        "\n",
        "    return loss"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VB2cDp8Gtm41"
      },
      "source": [
        "### Model\n",
        "Now we use ResNet-50 again, just as in finetuning. The difference is that we wrap ResNet-50 in a new class that drops its last (FC) layer and returns the pooled features."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "UOLx_OSxtm41"
      },
      "source": [
        "from torchvision import models\n",
        "class ResNet50Fc(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(ResNet50Fc, self).__init__()\n",
        "        model_resnet50 = models.resnet50(pretrained=True)\n",
        "        self.conv1 = model_resnet50.conv1\n",
        "        self.bn1 = model_resnet50.bn1\n",
        "        self.relu = model_resnet50.relu\n",
        "        self.maxpool = model_resnet50.maxpool\n",
        "        self.layer1 = model_resnet50.layer1\n",
        "        self.layer2 = model_resnet50.layer2\n",
        "        self.layer3 = model_resnet50.layer3\n",
        "        self.layer4 = model_resnet50.layer4\n",
        "        self.avgpool = model_resnet50.avgpool\n",
        "        self.__in_features = model_resnet50.fc.in_features\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.conv1(x)\n",
        "        x = self.bn1(x)\n",
        "        x = self.relu(x)\n",
        "        x = self.maxpool(x)\n",
        "        x = self.layer1(x)\n",
        "        x = self.layer2(x)\n",
        "        x = self.layer3(x)\n",
        "        x = self.layer4(x)\n",
        "        x = self.avgpool(x)\n",
        "        x = x.view(x.size(0), -1)\n",
        "        return x\n",
        "\n",
        "    def output_num(self):\n",
        "        return self.__in_features"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IRgdNM3Wtm42"
      },
      "source": [
        "Now the main class for DA. We take ResNet-50 as its backbone and add a bottleneck layer and our own FC layers for classification.\n",
        "Note the `adapt_loss` function. It simply calls the MMD or CORAL loss defined above. Of course, you can use your own loss."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oC5NKJpJtm42"
      },
      "source": [
        "class TransferNet(nn.Module):\n",
        "    def __init__(self,\n",
        "                 num_class,\n",
        "                 base_net='resnet50',\n",
        "                 transfer_loss='mmd',\n",
        "                 use_bottleneck=True,\n",
        "                 bottleneck_width=256,\n",
        "                 width=1024):\n",
        "        super(TransferNet, self).__init__()\n",
        "        if base_net == 'resnet50':\n",
        "            self.base_network = ResNet50Fc()\n",
        "        else:\n",
        "            # Your own basenet; fail loudly instead of returning a half-initialized module\n",
        "            raise NotImplementedError(f'Unsupported base network: {base_net}')\n",
        "        self.use_bottleneck = use_bottleneck\n",
        "        self.transfer_loss = transfer_loss\n",
        "        bottleneck_list = [nn.Linear(self.base_network.output_num(\n",
        "        ), bottleneck_width), nn.BatchNorm1d(bottleneck_width), nn.ReLU(), nn.Dropout(0.5)]\n",
        "        self.bottleneck_layer = nn.Sequential(*bottleneck_list)\n",
        "        classifier_layer_list = [nn.Linear(self.base_network.output_num(), width), nn.ReLU(), nn.Dropout(0.5),\n",
        "                                 nn.Linear(width, num_class)]\n",
        "        self.classifier_layer = nn.Sequential(*classifier_layer_list)\n",
        "\n",
        "        self.bottleneck_layer[0].weight.data.normal_(0, 0.005)\n",
        "        self.bottleneck_layer[0].bias.data.fill_(0.1)\n",
        "        for i in range(2):\n",
        "            self.classifier_layer[i * 3].weight.data.normal_(0, 0.01)\n",
        "            self.classifier_layer[i * 3].bias.data.fill_(0.0)\n",
        "\n",
        "    def forward(self, source, target):\n",
        "        source = self.base_network(source)\n",
        "        target = self.base_network(target)\n",
        "        source_clf = self.classifier_layer(source)\n",
        "        if self.use_bottleneck:\n",
        "            source = self.bottleneck_layer(source)\n",
        "            target = self.bottleneck_layer(target)\n",
        "        transfer_loss = self.adapt_loss(source, target, self.transfer_loss)\n",
        "        return source_clf, transfer_loss\n",
        "\n",
        "    def predict(self, x):\n",
        "        features = self.base_network(x)\n",
        "        clf = self.classifier_layer(features)\n",
        "        return clf\n",
        "\n",
        "    def adapt_loss(self, X, Y, adapt_loss):\n",
        "        \"\"\"Compute adaptation loss, currently we support mmd and coral\n",
        "\n",
        "        Arguments:\n",
        "            X {tensor} -- source matrix\n",
        "            Y {tensor} -- target matrix\n",
        "            adapt_loss {string} -- loss type, 'mmd' or 'coral'. You can add your own loss\n",
        "\n",
        "        Returns:\n",
        "            [tensor] -- adaptation loss tensor\n",
        "        \"\"\"\n",
        "        if adapt_loss == 'mmd':\n",
        "            mmd_loss = MMD_loss()\n",
        "            loss = mmd_loss(X, Y)\n",
        "        elif adapt_loss == 'coral':\n",
        "            loss = CORAL(X, Y)\n",
        "        else:\n",
        "            # Your own loss\n",
        "            loss = 0\n",
        "        return loss"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hAdPs27btm42"
      },
      "source": [
        "### Train\n",
        "Now the train part."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OK6P8uMDtm42"
      },
      "source": [
        "transfer_loss = 'mmd'\n",
        "learning_rate = 0.0001\n",
        "transfer_model = TransferNet(n_class, transfer_loss=transfer_loss, base_net='resnet50').cuda()\n",
        "optimizer = torch.optim.SGD([\n",
        "    {'params': transfer_model.base_network.parameters()},\n",
        "    {'params': transfer_model.bottleneck_layer.parameters(), 'lr': 10 * learning_rate},\n",
        "    {'params': transfer_model.classifier_layer.parameters(), 'lr': 10 * learning_rate},\n",
        "], lr=learning_rate, momentum=0.9, weight_decay=5e-4)\n",
        "lamb = 10 # weight for transfer loss, it is a hyperparameter that needs to be tuned"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4WpUfcHItm42"
      },
      "source": [
        "The main train function. Each training step needs one source batch and one target batch, so we use `zip` to iterate over the two loaders in parallel. Note that `zip` stops at the shorter loader, so when the two domains have different sizes (which is common), a single epoch does not cover every sample of the larger domain. Because the source loader is reshuffled at every epoch, we expect to sample most of them over many epochs."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AGlVNI2ktm42"
      },
      "source": [
        "def train(dataloaders, model, optimizer):\n",
        "    source_loader, target_train_loader, target_test_loader = dataloaders['src'], dataloaders['val'], dataloaders['tar']\n",
        "    len_source_loader = len(source_loader)\n",
        "    len_target_loader = len(target_train_loader)\n",
        "    best_acc = 0\n",
        "    stop = 0\n",
        "    n_batch = min(len_source_loader, len_target_loader)\n",
        "    for e in range(n_epoch):\n",
        "        stop += 1\n",
        "        train_loss_clf, train_loss_transfer, train_loss_total = 0, 0, 0\n",
        "        model.train()\n",
        "        for (src, tar) in zip(source_loader, target_train_loader):\n",
        "            data_source, label_source = src\n",
        "            data_target, _ = tar\n",
        "            data_source, label_source = data_source.cuda(), label_source.cuda()\n",
        "            data_target = data_target.cuda()\n",
        "\n",
        "            optimizer.zero_grad()\n",
        "            label_source_pred, transfer_loss = model(data_source, data_target)\n",
        "            clf_loss = criterion(label_source_pred, label_source)\n",
        "            loss = clf_loss + lamb * transfer_loss\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "            train_loss_clf = clf_loss.detach().item() + train_loss_clf\n",
        "            train_loss_transfer = transfer_loss.detach().item() + train_loss_transfer\n",
        "            train_loss_total = loss.detach().item() + train_loss_total\n",
        "        acc = test(model, target_test_loader)\n",
        "        print(f'Epoch: [{e:2d}/{n_epoch}], cls_loss: {train_loss_clf/n_batch:.4f}, transfer_loss: {train_loss_transfer/n_batch:.4f}, total_Loss: {train_loss_total/n_batch:.4f}, acc: {acc:.4f}')\n",
        "        if best_acc < acc:\n",
        "            best_acc = acc\n",
        "            torch.save(model.state_dict(), 'trans_model.pkl')\n",
        "            stop = 0\n",
        "        if stop >= early_stop:\n",
        "            break"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "scrolled": true,
        "id": "nqSCG6-Xtm43"
      },
      "source": [
        "train(dataloaders, transfer_model, optimizer)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yWvYODRDtm43"
      },
      "source": [
        "transfer_model.load_state_dict(torch.load('trans_model.pkl'))\n",
        "acc_test = test(transfer_model, dataloaders['tar'])\n",
        "print(f'Test accuracy: {acc_test}')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jbw6KndNtm43"
      },
      "source": [
        "Now we are done."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JQX2ak-jtm43"
      },
      "source": [
        "You see, we don't even need to install a library or package to train a domain adaptation or finetune model.\n",
        "In your own work, you can also use this notebook to test your own algorithms."
      ]
    }
  ]
}